{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "0",
   "metadata": {},
   "source": [
    "# Fine-tune the pretrained CHGNet for better accuracy\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1",
   "metadata": {},
   "outputs": [],
   "source": [
    "try:\n",
    "    from chgnet.model import CHGNet\n",
    "except ImportError:\n",
    "    # install CHGNet (only needed on Google Colab or if you didn't install CHGNet yet)\n",
    "    !pip install chgnet"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from pymatgen.core import Structure\n",
    "\n",
    "# If the import above fails on Google Colab because of a NumPy version conflict,\n",
    "# restart the runtime; that should resolve the problem."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3",
   "metadata": {},
   "source": [
    "## 0. Parse DFT outputs to CHGNet-readable formats\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4",
   "metadata": {},
   "source": [
    "CHGNet is interfaced with [Pymatgen](https://pymatgen.org/), so training samples (normally produced by DFT codes such as VASP)\n",
    "need to be converted to [pymatgen.core.structure](https://pymatgen.org/pymatgen.core.html#module-pymatgen.core.structure) objects.\n",
    "\n",
    "To convert a VASP calculation to pymatgen structures and CHGNet labels, you can use the following [code](https://github.com/CederGroupHub/chgnet/blob/main/chgnet/utils/vasp_utils.py):\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "from chgnet.utils import parse_vasp_dir\n",
    "\n",
    "# ./my_vasp_calc_dir contains vasprun.xml OSZICAR etc.\n",
    "dataset_dict = parse_vasp_dir(\n",
    "    file_root=\"./my_vasp_calc_dir\", save_path=\"./my_vasp_calc_dir/chgnet_dataset.json\"\n",
    ")\n",
    "print(list(dataset_dict))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6",
   "metadata": {},
   "source": [
    "The parsed Python dictionary contains the CHGNet inputs (structures) and the CHGNet prediction labels (energy, force, stress, magmom).\n",
    "\n",
    "We can save the parsed structures and labels to disk so that they can be easily reloaded across multiple rounds of training.\n",
    "\n",
    "The JSON file is written when you provide `save_path`, as in the call above.\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7",
   "metadata": {},
   "source": [
    "The Pymatgen structures can also be saved separately if you want to inspect each structure.\n",
    "\n",
    "Below is example code that saves the structures as JSON, pickle, CIF, or CHGNet graphs.\n",
    "\n",
    "For very large training datasets, such as MPtrj, we recommend [converting them to CHGNet graphs](https://github.com/CederGroupHub/chgnet/blob/main/examples/make_graphs.py). This saves significant memory and graph computing time.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "# Structure to json\n",
    "from chgnet.utils import write_json\n",
    "\n",
    "dict_to_json = [struct.as_dict() for struct in dataset_dict[\"structure\"]]\n",
    "write_json(dict_to_json, \"CHGNet_structures.json\")\n",
    "\n",
    "\n",
    "# Whole dataset (structures and labels) to pickle\n",
    "import pickle\n",
    "\n",
    "with open(\"CHGNet_structures.p\", \"wb\") as f:\n",
    "    pickle.dump(dataset_dict, f)\n",
    "\n",
    "\n",
    "# Structure to cif\n",
    "for idx, struct in enumerate(dataset_dict[\"structure\"]):\n",
    "    struct.to(filename=f\"{idx}.cif\")\n",
    "\n",
    "\n",
    "# Structure to CHGNet graph\n",
    "from chgnet.graph import CrystalGraphConverter\n",
    "\n",
    "converter = CrystalGraphConverter()\n",
    "for idx, struct in enumerate(dataset_dict[\"structure\"]):\n",
    "    graph = converter(struct)\n",
    "    graph.save(fname=f\"{idx}.pt\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9",
   "metadata": {},
   "source": [
    "For other DFT codes, please refer to their interfaces\n",
    "in [pymatgen.io](https://pymatgen.org/pymatgen.io.html#module-pymatgen.io).\n",
    "\n",
    "See: [Quantum Espresso](https://pymatgen.org/pymatgen.io.html#module-pymatgen.io.pwscf)\n",
    "\n",
    "See: [CP2K](https://pymatgen.org/pymatgen.io.cp2k.html#module-pymatgen.io.cp2k)\n",
    "\n",
    "See: [Gaussian](https://pymatgen.org/pymatgen.io.html#module-pymatgen.io.gaussian)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "10",
   "metadata": {},
   "source": [
    "## 1. Prepare Training Data\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "11",
   "metadata": {},
   "source": [
    "If you have parsed your VASP labels in step 0, you can reload the saved JSON file.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "12",
   "metadata": {},
   "outputs": [],
   "source": [
    "from chgnet.utils import read_json\n",
    "\n",
    "dataset_dict = read_json(\"./my_vasp_calc_dir/chgnet_dataset.json\")\n",
    "structures = [Structure.from_dict(struct) for struct in dataset_dict[\"structure\"]]\n",
    "# named energies_per_atom to match the variable passed to StructureData in step 2\n",
    "energies_per_atom = dataset_dict[\"energy_per_atom\"]\n",
    "forces = dataset_dict[\"force\"]\n",
    "stresses = dataset_dict.get(\"stress\") or None\n",
    "magmoms = dataset_dict.get(\"magmom\") or None"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "13",
   "metadata": {},
   "source": [
    "If you don't have any DFT calculations yet, you can create a dummy fine-tuning dataset from CHGNet predictions with some random noise added.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "14",
   "metadata": {},
   "outputs": [],
   "source": [
    "try:\n",
    "    from chgnet import ROOT\n",
    "\n",
    "    lmo = Structure.from_file(f\"{ROOT}/examples/mp-18767-LiMnO2.cif\")\n",
    "except Exception:\n",
    "    from urllib.request import urlopen\n",
    "\n",
    "    url = \"https://raw.githubusercontent.com/CederGroupHub/chgnet/main/examples/mp-18767-LiMnO2.cif\"\n",
    "    cif = urlopen(url).read().decode(\"utf-8\")\n",
    "    lmo = Structure.from_str(cif, fmt=\"cif\")\n",
    "\n",
    "structures, energies_per_atom, forces, stresses, magmoms = [], [], [], [], []\n",
    "chgnet = CHGNet.load()\n",
    "for _ in range(100):\n",
    "    structure = lmo.copy()\n",
    "    # strain each lattice vector by a random amount between -10% and +10%\n",
    "    structure.apply_strain(np.random.uniform(-0.1, 0.1, size=3))\n",
    "    # perturb all atom positions by a small amount\n",
    "    structure.perturb(0.1)\n",
    "\n",
    "    pred = chgnet.predict_structure(structure)\n",
    "\n",
    "    structures.append(structure)\n",
    "    energies_per_atom.append(pred[\"e\"] + np.random.uniform(-0.1, 0.1, size=1))\n",
    "    forces.append(pred[\"f\"] + np.random.uniform(-0.01, 0.01, size=pred[\"f\"].shape))\n",
    "    stresses.append(\n",
    "        pred[\"s\"] * -10 + np.random.uniform(-0.05, 0.05, size=pred[\"s\"].shape)\n",
    "    )\n",
    "    magmoms.append(pred[\"m\"] + np.random.uniform(-0.03, 0.03, size=pred[\"m\"].shape))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "15",
   "metadata": {},
   "source": [
    "Note that CHGNet outputs stress in GPa. The factor of -10 above converts it to kbar,\n",
    "the raw unit and sign convention used by VASP.\n",
    "If you're using stress labels from VASP, you don't need to do any unit conversion:\n",
    "the `StructureData` dataset class takes in VASP units.\n"
   ]
  },
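  {
   "cell_type": "markdown",
   "id": "15a",
   "metadata": {},
   "source": [
    "As a worked illustration of the conversion described above (a sketch that uses only the factor of -10 from the dummy-data cell and the identity 1 GPa = 10 kbar; the numerical value is a made-up example, not a CHGNet output):\n",
    "\n",
    "$$\n",
    "\\sigma_{\\text{VASP}}\\,[\\text{kbar}] = -10 \\times \\sigma_{\\text{CHGNet}}\\,[\\text{GPa}],\n",
    "\\qquad \\text{e.g. } \\sigma_{\\text{CHGNet}} = 0.5\\ \\text{GPa} \\;\\Rightarrow\\; \\sigma_{\\text{VASP}} = -5\\ \\text{kbar}\n",
    "$$\n",
    "\n",
    "The factor of 10 changes the unit, and the minus sign flips the sign convention.\n"
   ]
  },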
  {
   "cell_type": "markdown",
   "id": "16",
   "metadata": {},
   "source": [
    "## 2. Define Dataset\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "17",
   "metadata": {},
   "outputs": [],
   "source": [
    "from chgnet.data.dataset import StructureData, get_train_val_test_loader"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "18",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "100 structures imported\n"
     ]
    }
   ],
   "source": [
    "dataset = StructureData(\n",
    "    structures=structures,\n",
    "    energies=energies_per_atom,\n",
    "    forces=forces,\n",
    "    stresses=stresses,  # can be None\n",
    "    magmoms=magmoms,  # can be None\n",
    ")\n",
    "train_loader, val_loader, test_loader = get_train_val_test_loader(\n",
    "    dataset, batch_size=8, train_ratio=0.9, val_ratio=0.05\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "19",
   "metadata": {},
   "source": [
    "Alternatively, the dataset can be created directly from a VASP calculation directory.\n",
    "This function parses the VASP directory, saves the labels to a JSON file, and creates the `StructureData` object.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "20",
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset = StructureData.from_vasp(\n",
    "    file_root=\"./my_vasp_calc_dir\", save_path=\"./my_vasp_calc_dir/chgnet_dataset.json\"\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "21",
   "metadata": {},
   "source": [
    "The training set is used to optimize CHGNet through gradient descent, the validation set is used to monitor the validation error at the end of each epoch, and the test set is used to report the final test error at the end of training. The test set is optional.\n",
    "\n",
    "The `batch_size` is set to 8 to suit GPUs with little memory. If more than 10 GB of memory is available, we highly recommend increasing `batch_size` for better speed.\n",
    "\n",
    "If you have a very large number (>100K) of structures (which is typical for AIMD), putting them all in a Python list can quickly run into memory issues. In this case we highly recommend pre-converting all the structures into graphs and saving them as shown in `examples/make_graphs.py`. You can then train CHGNet by loading the graphs from disk instead of memory using the `GraphData` class defined in `data/dataset.py`.\n"
   ]
  },
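  {
   "cell_type": "markdown",
   "id": "21a",
   "metadata": {},
   "source": [
    "With the 100 dummy structures and the ratios used above, the split sizes work out as follows (a back-of-the-envelope sketch, assuming the test set receives the remaining fraction $1 - 0.9 - 0.05$ and that a partially filled last batch is kept):\n",
    "\n",
    "$$\n",
    "N_{\\text{train}} = 0.9 \\times 100 = 90, \\qquad N_{\\text{val}} = 0.05 \\times 100 = 5, \\qquad N_{\\text{test}} = 100 - 90 - 5 = 5\n",
    "$$\n",
    "\n",
    "$$\n",
    "\\text{training batches per epoch} = \\left\\lceil \\frac{90}{8} \\right\\rceil = 12\n",
    "$$\n",
    "\n",
    "This agrees with the `[1/12]` to `[12/12]` batch counters printed in the training log in section 4.\n"
   ]
  },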
  {
   "cell_type": "markdown",
   "id": "22",
   "metadata": {},
   "source": [
    "## 3. Define model and trainer\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "23",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CHGNet v0.3.0 initialized with 412,525 parameters\n",
      "CHGNet will run on cpu\n"
     ]
    }
   ],
   "source": [
    "from chgnet.model import CHGNet\n",
    "from chgnet.trainer import Trainer\n",
    "\n",
    "# Load pretrained CHGNet\n",
    "chgnet = CHGNet.load()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "24",
   "metadata": {},
   "source": [
    "Freezing the weights of some layers is optional. It is a common technique for retaining learned knowledge when fine-tuning large pretrained neural networks. You can choose which layers to freeze.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "25",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Optionally fix the weights of some layers\n",
    "for layer in [\n",
    "    chgnet.atom_embedding,\n",
    "    chgnet.bond_embedding,\n",
    "    chgnet.angle_embedding,\n",
    "    chgnet.bond_basis_expansion,\n",
    "    chgnet.angle_basis_expansion,\n",
    "    chgnet.atom_conv_layers[:-1],\n",
    "    chgnet.bond_conv_layers,\n",
    "    chgnet.angle_layers,\n",
    "]:\n",
    "    for param in layer.parameters():\n",
    "        param.requires_grad = False"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "26",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Define Trainer\n",
    "trainer = Trainer(\n",
    "    model=chgnet,\n",
    "    targets=\"efsm\",\n",
    "    optimizer=\"Adam\",\n",
    "    scheduler=\"CosLR\",\n",
    "    criterion=\"MSE\",\n",
    "    epochs=5,\n",
    "    learning_rate=1e-2,\n",
    "    use_device=\"cpu\",\n",
    "    print_freq=6,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "27",
   "metadata": {},
   "source": [
    "## 4. Start training\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "28",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Begin Training: using cpu device\n",
      "training targets: efsm\n",
      "Epoch: [0][1/12]\tTime (0.476)  Data (0.016)  Loss 0.0033 (0.0033)  MAEs:  e 0.053 (0.053)  f 0.004 (0.004)  s 0.002 (0.002)  m 0.016 (0.016)  \n",
      "Epoch: [0][6/12]\tTime (0.426)  Data (0.015)  Loss 0.0040 (0.0039)  MAEs:  e 0.054 (0.056)  f 0.005 (0.005)  s 0.002 (0.002)  m 0.015 (0.015)  \n",
      "Epoch: [0][12/12]\tTime (0.414)  Data (0.014)  Loss 0.0040 (0.0038)  MAEs:  e 0.054 (0.054)  f 0.005 (0.005)  s 0.002 (0.002)  m 0.015 (0.014)  \n",
      "*   e_MAE (0.028) \tf_MAE (0.006) \ts_MAE (0.002) \tm_MAE (0.015) \t\n",
      "Epoch: [1][1/12]\tTime (0.409)  Data (0.000)  Loss 0.0052 (0.0052)  MAEs:  e 0.064 (0.064)  f 0.005 (0.005)  s 0.002 (0.002)  m 0.013 (0.013)  \n",
      "Epoch: [1][6/12]\tTime (0.393)  Data (0.000)  Loss 0.0036 (0.0039)  MAEs:  e 0.053 (0.055)  f 0.005 (0.005)  s 0.002 (0.002)  m 0.014 (0.014)  \n",
      "Epoch: [1][12/12]\tTime (0.371)  Data (0.000)  Loss 0.0029 (0.0038)  MAEs:  e 0.053 (0.054)  f 0.005 (0.005)  s 0.003 (0.002)  m 0.012 (0.014)  \n",
      "*   e_MAE (0.028) \tf_MAE (0.006) \ts_MAE (0.002) \tm_MAE (0.015) \t\n",
      "Epoch: [2][1/12]\tTime (0.389)  Data (0.000)  Loss 0.0056 (0.0056)  MAEs:  e 0.065 (0.065)  f 0.005 (0.005)  s 0.002 (0.002)  m 0.015 (0.015)  \n",
      "Epoch: [2][6/12]\tTime (0.377)  Data (0.000)  Loss 0.0042 (0.0046)  MAEs:  e 0.059 (0.062)  f 0.005 (0.005)  s 0.002 (0.002)  m 0.014 (0.014)  \n",
      "Epoch: [2][12/12]\tTime (0.350)  Data (0.000)  Loss 0.0025 (0.0038)  MAEs:  e 0.048 (0.054)  f 0.005 (0.005)  s 0.002 (0.002)  m 0.008 (0.014)  \n",
      "*   e_MAE (0.028) \tf_MAE (0.006) \ts_MAE (0.002) \tm_MAE (0.015) \t\n",
      "Epoch: [3][1/12]\tTime (0.363)  Data (0.000)  Loss 0.0049 (0.0049)  MAEs:  e 0.065 (0.065)  f 0.005 (0.005)  s 0.002 (0.002)  m 0.014 (0.014)  \n",
      "Epoch: [3][6/12]\tTime (0.359)  Data (0.000)  Loss 0.0050 (0.0042)  MAEs:  e 0.066 (0.057)  f 0.005 (0.005)  s 0.003 (0.002)  m 0.014 (0.014)  \n",
      "Epoch: [3][12/12]\tTime (0.355)  Data (0.000)  Loss 0.0045 (0.0038)  MAEs:  e 0.059 (0.054)  f 0.004 (0.005)  s 0.003 (0.002)  m 0.012 (0.014)  \n",
      "*   e_MAE (0.028) \tf_MAE (0.006) \ts_MAE (0.002) \tm_MAE (0.015) \t\n",
      "Epoch: [4][1/12]\tTime (0.384)  Data (0.000)  Loss 0.0033 (0.0033)  MAEs:  e 0.051 (0.051)  f 0.005 (0.005)  s 0.003 (0.003)  m 0.015 (0.015)  \n",
      "Epoch: [4][6/12]\tTime (0.384)  Data (0.000)  Loss 0.0016 (0.0033)  MAEs:  e 0.035 (0.051)  f 0.005 (0.005)  s 0.002 (0.002)  m 0.012 (0.014)  \n",
      "Epoch: [4][12/12]\tTime (0.351)  Data (0.000)  Loss 0.0011 (0.0038)  MAEs:  e 0.033 (0.054)  f 0.004 (0.005)  s 0.002 (0.002)  m 0.014 (0.014)  \n",
      "*   e_MAE (0.028) \tf_MAE (0.006) \ts_MAE (0.002) \tm_MAE (0.015) \t\n",
      "---------Evaluate Model on Test Set---------------\n",
      "**  e_MAE (0.056) \tf_MAE (0.005) \ts_MAE (0.003) \tm_MAE (0.015) \t\n"
     ]
    }
   ],
   "source": [
    "trainer.train(train_loader, val_loader, test_loader)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "29",
   "metadata": {},
   "source": [
    "After training, the trained model can be found in a directory named after today's date. It can also be accessed directly from the trainer:\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "30",
   "metadata": {},
   "outputs": [],
   "source": [
    "model = trainer.model\n",
    "best_model = trainer.best_model  # best model based on validation energy MAE"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "31",
   "metadata": {},
   "source": [
    "## Extras 1: GGA / GGA+U compatibility\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "32",
   "metadata": {},
   "source": [
    "### Q: Why and when do you care about this?\n",
    "\n",
    "**When**: If you want to fine-tune the pretrained CHGNet with your own GGA+U VASP calculations and keep your VASP energies compatible with the pretraining dataset. If your dataset is so large that the pretrained knowledge does not matter to you, you can ignore this.\n",
    "\n",
    "**Why**: CHGNet is trained on both GGA and GGA+U calculations from the Materials Project. Methods have been developed to reconcile GGA and GGA+U calculations, which makes the energies universally applicable for cross-chemistry comparison and phase-diagram construction. Please refer to:\n",
    "\n",
    "https://journals.aps.org/prb/abstract/10.1103/PhysRevB.84.045115\n",
    "\n",
    "Below we show an example of applying the compatibility correction.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "33",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The raw total energy from VASP of LMO is: -58.97 eV\n"
     ]
    }
   ],
   "source": [
    "# Imagine this is the VASP raw energy\n",
    "vasp_raw_energy = -58.97\n",
    "\n",
    "print(f\"The raw total energy from VASP of LMO is: {vasp_raw_energy} eV\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "34",
   "metadata": {},
   "source": [
    "You can look up the energy correction applied to each element in:\n",
    "\n",
    "https://github.com/materialsproject/pymatgen/blob/v2023.2.28/pymatgen/entries/MP2020Compatibility.yaml\n",
    "\n",
    "Here LiMnO2 is subject to both the correction for Mn in transition metal oxides and the oxide correction.\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "35",
   "metadata": {},
   "source": [
    "To demystify `MaterialsProject2020Compatibility`, basically all that's happening is:\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "36",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The corrected total energy after MP2020 = -65.05 eV\n"
     ]
    }
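   ],
   "source": [
    "# MP2020 energy corrections in eV/atom\n",
    "Mn_correction_in_TMO = -1.668\n",
    "oxide_correction = -0.687\n",
    "# unpack the atom counts; this assumes the composition is ordered Li, Mn, O\n",
    "_, num_Mn, num_O = lmo.composition.values()\n",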
    "\n",
    "\n",
    "corrected_energy = (\n",
    "    vasp_raw_energy + num_Mn * Mn_correction_in_TMO + num_O * oxide_correction\n",
    ")\n",
    "print(f\"The corrected total energy after MP2020 = {corrected_energy:.4} eV\")"
   ]
  },
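  {
   "cell_type": "markdown",
   "id": "36a",
   "metadata": {},
   "source": [
    "Written out, the cell above evaluates the sum below, where $n$ counts atoms in the cell and the two $\\Delta E$ values are the per-atom corrections defined in the code. The numbers in the second line assume that the cell holds 2 Mn and 4 O atoms, an assumption that reproduces the printed -65.05 eV:\n",
    "\n",
    "$$\n",
    "E_{\\text{corrected}} = E_{\\text{raw}} + n_{\\text{Mn}}\\,\\Delta E_{\\text{Mn}} + n_{\\text{O}}\\,\\Delta E_{\\text{oxide}}\n",
    "$$\n",
    "\n",
    "$$\n",
    "-58.97 + 2 \\times (-1.668) + 4 \\times (-0.687) = -65.054\\ \\text{eV}\n",
    "$$\n"
   ]
  },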
  {
   "cell_type": "markdown",
   "id": "37",
   "metadata": {},
   "source": [
    "You can also apply the `MaterialsProject2020Compatibility` through pymatgen\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "38",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The total energy of LMO after MP2020Compatibility correction = -62.31 eV\n"
     ]
    }
   ],
   "source": [
    "from pymatgen.entries.compatibility import MaterialsProject2020Compatibility\n",
    "from pymatgen.entries.computed_entries import ComputedStructureEntry\n",
    "\n",
    "params = {\"hubbards\": {\"Mn\": 3.9, \"O\": 0, \"Li\": 0}, \"run_type\": \"GGA+U\"}\n",
    "\n",
    "cse = ComputedStructureEntry(lmo, vasp_raw_energy, parameters=params)\n",
    "\n",
    "MaterialsProject2020Compatibility(check_potcar=False).process_entries(cse)\n",
    "print(\n",
    "    f\"The total energy of LMO after MP2020Compatibility correction = {cse.energy:.4} eV\"\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "39",
   "metadata": {},
   "source": [
    "Now use this corrected energy as the label to fine-tune CHGNet, and you're good to go!\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "40",
   "metadata": {},
   "source": [
    "## Extras 2: AtomRef\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "41",
   "metadata": {},
   "source": [
    "### Q: Why and when do you care about this?\n",
    "\n",
    "**When**: When you fine-tune CHGNet to DFT labels that are incompatible with the Materials Project, such as labels from the r2SCAN functional or from other DFT codes like Gaussian or QE. The large shifts in elemental energies are not of interest and should be reconciled. For example, Li has -0.95 eV/atom in GGA (https://next-gen.materialsproject.org/materials/mp-135/tasks/mp-990455) and -1.17 eV/atom in r2SCAN (https://next-gen.materialsproject.org/materials/mp-135/tasks/mp-1943895).\n",
    "\n",
    "**Why**: The GNN learns the interactions between atoms, while the composition model (AtomRef) in CHGNet normalizes the elemental energy contribution, similar to a formation-energy-like calculation. During fine-tuning, we want to keep most of the knowledge in the GNN unchanged and let the AtomRef shift to absorb the change in elemental energies, so that fine-tuning of the graph layers can focus on the energy contribution from atom-atom interactions instead of meaningless atom reference energies.\n",
    "\n",
    "Below is an example of fitting the AtomRef layer:\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "42",
   "metadata": {},
   "source": [
    "### A quick and easy way to turn on training of AtomRef in the trainer (this is by default off):\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "43",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Begin Training: using cpu device\n",
      "training targets: efsm\n",
      "Epoch: [0][1/12]\tTime (0.475)  Data (0.001)  Loss 0.0028 (0.0028)  MAEs:  e 0.047 (0.047)  f 0.005 (0.005)  s 0.003 (0.003)  m 0.014 (0.014)  \n",
      "Epoch: [0][6/12]\tTime (0.379)  Data (0.000)  Loss 0.0027 (0.0037)  MAEs:  e 0.046 (0.053)  f 0.005 (0.005)  s 0.002 (0.002)  m 0.015 (0.014)  \n",
      "Epoch: [0][12/12]\tTime (0.359)  Data (0.000)  Loss 0.0010 (0.0038)  MAEs:  e 0.030 (0.054)  f 0.005 (0.005)  s 0.003 (0.002)  m 0.012 (0.014)  \n",
      "*   e_MAE (0.028) \tf_MAE (0.006) \ts_MAE (0.002) \tm_MAE (0.015) \t\n",
      "Epoch: [1][1/12]\tTime (0.417)  Data (0.000)  Loss 0.0011 (0.0011)  MAEs:  e 0.027 (0.027)  f 0.004 (0.004)  s 0.002 (0.002)  m 0.015 (0.015)  \n",
      "Epoch: [1][6/12]\tTime (0.359)  Data (0.000)  Loss 0.0049 (0.0040)  MAEs:  e 0.062 (0.056)  f 0.005 (0.005)  s 0.003 (0.002)  m 0.015 (0.015)  \n",
      "Epoch: [1][12/12]\tTime (0.351)  Data (0.000)  Loss 0.0054 (0.0038)  MAEs:  e 0.073 (0.054)  f 0.004 (0.005)  s 0.002 (0.002)  m 0.013 (0.014)  \n",
      "*   e_MAE (0.028) \tf_MAE (0.006) \ts_MAE (0.002) \tm_MAE (0.015) \t\n",
      "Epoch: [2][1/12]\tTime (0.368)  Data (0.000)  Loss 0.0027 (0.0027)  MAEs:  e 0.043 (0.043)  f 0.005 (0.005)  s 0.003 (0.003)  m 0.016 (0.016)  \n",
      "Epoch: [2][6/12]\tTime (0.388)  Data (0.000)  Loss 0.0042 (0.0034)  MAEs:  e 0.056 (0.051)  f 0.005 (0.005)  s 0.003 (0.003)  m 0.014 (0.015)  \n",
      "Epoch: [2][12/12]\tTime (0.354)  Data (0.000)  Loss 0.0033 (0.0038)  MAEs:  e 0.054 (0.054)  f 0.004 (0.005)  s 0.003 (0.002)  m 0.013 (0.014)  \n",
      "*   e_MAE (0.028) \tf_MAE (0.006) \ts_MAE (0.002) \tm_MAE (0.015) \t\n",
      "Epoch: [3][1/12]\tTime (0.351)  Data (0.000)  Loss 0.0032 (0.0032)  MAEs:  e 0.048 (0.048)  f 0.005 (0.005)  s 0.003 (0.003)  m 0.014 (0.014)  \n",
      "Epoch: [3][6/12]\tTime (0.371)  Data (0.000)  Loss 0.0046 (0.0035)  MAEs:  e 0.064 (0.052)  f 0.005 (0.005)  s 0.002 (0.003)  m 0.016 (0.014)  \n",
      "Epoch: [3][12/12]\tTime (0.351)  Data (0.000)  Loss 0.0088 (0.0038)  MAEs:  e 0.093 (0.054)  f 0.005 (0.005)  s 0.002 (0.002)  m 0.016 (0.014)  \n",
      "*   e_MAE (0.028) \tf_MAE (0.006) \ts_MAE (0.002) \tm_MAE (0.015) \t\n",
      "Epoch: [4][1/12]\tTime (0.376)  Data (0.000)  Loss 0.0048 (0.0048)  MAEs:  e 0.066 (0.066)  f 0.005 (0.005)  s 0.002 (0.002)  m 0.013 (0.013)  \n",
      "Epoch: [4][6/12]\tTime (0.375)  Data (0.000)  Loss 0.0017 (0.0036)  MAEs:  e 0.030 (0.053)  f 0.005 (0.005)  s 0.003 (0.002)  m 0.016 (0.014)  \n",
      "Epoch: [4][12/12]\tTime (0.351)  Data (0.000)  Loss 0.0006 (0.0038)  MAEs:  e 0.020 (0.054)  f 0.005 (0.005)  s 0.003 (0.002)  m 0.013 (0.014)  \n",
      "*   e_MAE (0.028) \tf_MAE (0.006) \ts_MAE (0.002) \tm_MAE (0.015) \t\n",
      "---------Evaluate Model on Test Set---------------\n",
      "**  e_MAE (0.056) \tf_MAE (0.005) \ts_MAE (0.003) \tm_MAE (0.015) \t\n"
     ]
    }
   ],
   "source": [
    "trainer.train(train_loader, val_loader, test_loader, train_composition_model=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "44",
   "metadata": {},
   "source": [
    "### The more rigorous way is to solve for the per-atom contribution by linear regression on your fine-tuning dataset\n"
   ]
  },
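  {
   "cell_type": "markdown",
   "id": "44a",
   "metadata": {},
   "source": [
    "As a sketch of what such a regression solves (assuming an intensive AtomRef, that is, energies per atom modeled from atomic fractions; the symbols below are introduced here for illustration and are not CHGNet identifiers):\n",
    "\n",
    "$$\n",
    "\\frac{E_s}{N_s} \\approx \\sum_Z x_{s,Z}\\, e_Z, \\qquad x_{s,Z} = \\frac{n_{s,Z}}{N_s}\n",
    "$$\n",
    "\n",
    "$$\n",
    "\\hat{e} = \\arg\\min_{e} \\sum_s \\left( \\frac{E_s}{N_s} - \\sum_Z x_{s,Z}\\, e_Z \\right)^2\n",
    "$$\n",
    "\n",
    "Here $n_{s,Z}$ is the number of atoms of element $Z$ in structure $s$, $N_s$ is the total number of atoms in that structure, and $e_Z$ is the fitted reference energy of element $Z$. For example, a structure with formula LiMnO2 contributes the row $x = (0.25, 0.25, 0.5)$ for (Li, Mn, O).\n"
   ]
  },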
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "45",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The pretrained Atom_Ref (per atom reference energy):\n",
      "Parameter containing:\n",
      "tensor([[ -3.4431,  -0.1279,  -2.8300,  -3.4737,  -7.4946,  -8.2354,  -8.1611,\n",
      "          -8.3861,  -5.7498,  -0.0236,  -1.7406,  -1.6788,  -4.2833,  -6.2002,\n",
      "          -6.1315,  -5.8405,  -3.8795,  -0.0703,  -1.5668,  -3.4451,  -7.0549,\n",
      "          -9.1465,  -9.2594,  -9.3514,  -8.9843,  -8.0228,  -6.4955,  -5.6057,\n",
      "          -3.4002,  -0.9217,  -3.2499,  -4.9164,  -4.7810,  -5.0191,  -3.3316,\n",
      "           0.5130,  -1.4043,  -3.2175,  -7.4994,  -9.3816, -10.4386,  -9.9539,\n",
      "          -7.9555,  -8.5440,  -7.3245,  -5.2771,  -1.9014,  -0.4034,  -2.6002,\n",
      "          -4.0054,  -4.1156,  -3.9928,  -2.7003,   2.2170,  -1.9671,  -3.7180,\n",
      "          -6.8133,  -7.3502,  -6.0712,  -6.1699,  -5.1471,  -6.1925, -11.5829,\n",
      "         -15.8841,  -5.9994,  -6.0798,  -5.9513,  -6.0400,  -5.9773,  -2.5091,\n",
      "          -6.0767, -10.6666, -11.8761, -11.8491, -10.7397,  -9.6100,  -8.4755,\n",
      "          -6.2070,  -3.0337,   0.4726,  -1.6425,  -3.1295,  -3.3328,  -0.1221,\n",
      "          -0.3448,  -0.4364,  -0.1661,  -0.3680,  -4.1869,  -8.4233, -10.0467,\n",
      "         -12.0953, -12.5228, -14.2530]], requires_grad=True)\n"
     ]
    }
   ],
   "source": [
    "print(\"The pretrained Atom_Ref (per atom reference energy):\")\n",
    "for param in chgnet.composition_model.parameters():\n",
    "    print(param)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "46",
   "metadata": {},
   "outputs": [],
   "source": [
    "# A list of structures / graphs\n",
    "structures = [\n",
    "    lmo,\n",
    "    Structure(\n",
    "        species=[\"Li\", \"Mn\", \"Mn\", \"O\", \"O\", \"O\"],\n",
    "        lattice=np.random.rand(3, 3),\n",
    "        coords=np.random.rand(6, 3),\n",
    "    ),\n",
    "    Structure(\n",
    "        species=[\"Li\", \"Li\", \"Mn\", \"O\", \"O\", \"O\"],\n",
    "        lattice=np.random.rand(3, 3),\n",
    "        coords=np.random.rand(6, 3),\n",
    "    ),\n",
    "    Structure(\n",
    "        species=[\"Li\", \"Mn\", \"Mn\", \"O\", \"O\", \"O\", \"O\"],\n",
    "        lattice=np.random.rand(3, 3),\n",
    "        coords=np.random.rand(7, 3),\n",
    "    ),\n",
    "]\n",
    "\n",
    "# A list of energy_per_atom values (random values here)\n",
    "energies_per_atom = [5.5, 6, 4.8, 5.6]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "47",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "We initialize another identical AtomRef layers\n",
      "tensor([[-3.4431, -0.1279, -2.8300]], grad_fn=<SliceBackward0>)\n"
     ]
    }
   ],
   "source": [
    "from chgnet.model.composition_model import AtomRef\n",
    "\n",
    "print(\"We initialize another identical AtomRef layers\")\n",
    "new_atom_ref = AtomRef(is_intensive=True)\n",
    "new_atom_ref.initialize_from_MPtrj()\n",
    "for param in new_atom_ref.parameters():\n",
    "    print(param[:, :3])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "48",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "After refitting, the AtomRef looks like:\n",
      "Parameter containing:\n",
      "tensor([[ 0.0000e+00,  0.0000e+00,  4.2667e+00, -3.3299e-15,  0.0000e+00,\n",
      "          0.0000e+00,  0.0000e+00,  2.9999e+00,  0.0000e+00,  0.0000e+00,\n",
      "          0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,\n",
      "          0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,\n",
      "          0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  1.1467e+01,\n",
      "          0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,\n",
      "          0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,\n",
      "          0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,\n",
      "          0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,\n",
      "          0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,\n",
      "          0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,\n",
      "          0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,\n",
      "          0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,\n",
      "          0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,\n",
      "          0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,\n",
      "          0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,\n",
      "          0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,\n",
      "          0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,\n",
      "          0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00]],\n",
      "       requires_grad=True)\n"
     ]
    }
   ],
   "source": [
    "# Solve linear regression to find the per atom contribution in your fine-tuning dataset\n",
    "\n",
    "new_atom_ref.fit(structures, energies_per_atom)\n",
    "print(\"After refitting, the AtomRef looks like:\")\n",
    "for param in new_atom_ref.parameters():\n",
    "    print(param)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
